% Copyright (C) 2019-2023 Benjamin Born, Francesco D'Ascanio, Gernot J. Mueller, Johannes Pfeifer
%
% This is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% (at your option) any later version.
%
% It is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
% 
% For a copy of the GNU General Public License,
% see <http://www.gnu.org/licenses/>.

function run_LP_wages(sign_iter)
% RUN_LP_WAGES Local-projection impulse responses to the 'PF_shock_G' shock,
% including negotiated wages, plotted into one column of a 5x3 subplot figure.
%
% Input:
%   sign_iter   1: the negative part of the shock is the first regressor;
%                  the responses are drawn in the first subplot column of a
%                  new figure that is saved as
%                  Figures/eur_asym_neg_real_fx_intraeuro_pf_fc (.fig and .eps)
%                  and then closed.
%               2: the positive part of the shock is the first regressor;
%                  the figure saved by the sign_iter==1 run is re-opened, the
%                  responses are added in the second subplot column and the
%                  result is saved as
%                  Figures/eur_asym_pos_real_fx_intraeuro_pf_fc (.fig and .pdf).
%                  This call therefore requires a previous sign_iter==1 run.
%               Any other value raises the error 'Case not implemented'.
%
% Output: none; the results are the saved figure files.
%
% Files read: dataset_BDMP.mat, negotiated_wages_test.mat and
% ../Sign_restriction_first_stage_shocks_euro_area/results/structural_shocks_point.mat.
% ../Auxiliary_Files is added to the path.
%
% NOTE(review): basic_timeline, country_header, country_indicator_names_mapping,
% Header, pos and data_array_for_regression_stacked_by_variable are not defined
% in this function; presumably they are created in the workspace by
% load('dataset_BDMP.mat') - confirm against the contents of that file.

if sign_iter==1
    sign_dummy = -1;
elseif sign_iter==2
    sign_dummy = 1;
else
    error('Case not implemented')
end
addpath('../Auxiliary_Files')
fontsize = 6;

figures_name = 'Figures';

lag_number=4;
max_irf_horizon = 8;
time_fixed_effects_dummy=1;
country_fixed_effects_dummy=1;

add_residual_dummy=1; %adds residuals from previous step to regression to increase efficiency;


%the naming of the following needs to be consistent with the field names in the pos structure

impulse_name = 'PF_shock_G';

special_name = 'pf_fc';


%% Construct sample

% Load data arrays for regression
load('dataset_BDMP.mat')
temp=load('../Sign_restriction_first_stage_shocks_euro_area/results/structural_shocks_point.mat');

% Append the four first-stage shock series as new slices along the third
% dimension and register their positions/labels in pos and Header.
data_array_for_regression_stacked_by_variable = cat(3,data_array_for_regression_stacked_by_variable,temp.structural_shocks.BP_G,temp.structural_shocks.BP_T,temp.structural_shocks.PF_G,temp.structural_shocks.PF_T);
pos.BP_shock_G_CK=length(fieldnames(pos))+1;
Header{1,pos.BP_shock_G_CK}='BP_shock G based on Caldara/Kamps';
pos.BP_shock_T_CK=length(fieldnames(pos))+1;
Header{1,pos.BP_shock_T_CK}='BP_shock T based on Caldara/Kamps';
pos.PF_shock_G=length(fieldnames(pos))+1;
Header{1,pos.PF_shock_G}='Sign restriction shock G based on Caldara/Kamps';
pos.PF_shock_T=length(fieldnames(pos))+1;
Header{1,pos.PF_shock_T}='Sign restriction shock T based on Caldara/Kamps';

%% get wage_growth_negotiated_wages

temp=load('negotiated_wages_test.mat');

% Rows of the two timelines that overlap (1e-10 tolerance on the date
% comparison). They do not depend on the country or the variable, so they are
% computed once instead of inside every loop iteration.
rows_in_basic_timeline = basic_timeline-temp.timeline_panel(1)>-1e-10 & basic_timeline-temp.timeline_panel(end)<1e-10;
rows_in_wage_timeline = temp.timeline_panel-basic_timeline(1)>-1e-10 & temp.timeline_panel-basic_timeline(end)<1e-10;

var_string={'wage_growth_negotiated_wages'};
for var_iter=1:length(var_string)
    wage_var_name=var_string{var_iter};

    % level, lags and leads: slices are ordered [level, lag_1..lag_n, lead_1..lead_n]
    output_series=NaN(size(basic_timeline,1),size(country_header,2),1+temp.lagnumber+temp.leadnumber);
    for country_iter=1:size(country_header,2)
        % NOTE(review): strmatch is flagged as not recommended in current MATLAB
        % releases; it is kept because the type of temp.country_names is not
        % visible here.
        country_pos=strmatch(country_header{1,country_iter},temp.country_names,'exact');
        if isempty(country_pos)
            fprintf('Missing country %s\n',country_header{1,country_iter})
        else
            input_matrix=temp.data_mat(:,country_pos,[temp.pos.(wage_var_name) ...
                temp.pos.([wage_var_name '_lag_1']):temp.pos.([wage_var_name '_lag_' num2str(temp.lagnumber)])...
                temp.pos.([wage_var_name '_lead_1']):temp.pos.([wage_var_name '_lead_' num2str(temp.leadnumber)])]);
            output_series(rows_in_basic_timeline,country_iter,:)=input_matrix(rows_in_wage_timeline,:);
        end
    end

    data_array_for_regression_stacked_by_variable = cat(3,data_array_for_regression_stacked_by_variable,output_series);
    pos.(wage_var_name)=length(fieldnames(pos))+1;
    Header{1,pos.(wage_var_name)}=wage_var_name;

    for iter=1:temp.lagnumber
        lag_field=[wage_var_name '_lag_' num2str(iter)];
        pos.(lag_field)=length(fieldnames(pos))+1;
        Header{1,pos.(lag_field)}=lag_field;
    end
    for iter=1:temp.leadnumber
        lead_field=[wage_var_name '_lead_' num2str(iter)];
        pos.(lead_field)=length(fieldnames(pos))+1;
        Header{1,pos.(lead_field)}=lead_field;
    end

    %do cumsum: slices '_cumsum_0' to '_cumsum_<leadnumber>'
    output_series=NaN(size(basic_timeline,1),size(country_header,2),1+temp.leadnumber);
    for country_iter=1:size(country_header,2)
        country_pos=strmatch(country_header{1,country_iter},temp.country_names,'exact');
        if isempty(country_pos)
            fprintf('Missing country %s\n',country_header{1,country_iter})
        else
            input_matrix=temp.data_mat(:,country_pos,[temp.pos.([wage_var_name '_cumsum_0']):temp.pos.([wage_var_name '_cumsum_' num2str(temp.leadnumber)])]);
            output_series(rows_in_basic_timeline,country_iter,:)=input_matrix(rows_in_wage_timeline,:);
        end
    end
    data_array_for_regression_stacked_by_variable = cat(3,data_array_for_regression_stacked_by_variable,output_series);
    for iter=0:temp.leadnumber
        cumsum_field=[wage_var_name '_cumsum_' num2str(iter)];
        pos.(cumsum_field)=length(fieldnames(pos))+1;
        Header{1,pos.(cumsum_field)}=cumsum_field;
    end

end

if length(Header)~=length(fieldnames(pos)) ||  length(Header)~=size(data_array_for_regression_stacked_by_variable,3)
    error('Header is not correctly defined')
end

% check whether positions are unique (and stored in increasing order)
pos_numbers=cell2mat(struct2cell(pos));
unique_pos_numbers=unique(pos_numbers);
if ~isequal(pos_numbers,unique_pos_numbers)
    % show the position numbers that occur more than once; the list is empty
    % if the check failed only because the positions are not in increasing order
    sorted_pos_numbers=sort(pos_numbers);
    duplicated_pos_numbers=unique(sorted_pos_numbers(diff(sorted_pos_numbers)==0));
    disp(duplicated_pos_numbers)
    error('The position numbers are wrong')
end
% Set dependent variables and regressors

regressor_var_names={'g_real_demeaned_linearly_detrended', 'y_real_demeaned_linearly_detrended', 'Effective_FX_Real_Intra_Euro_CPI_log', 'tax_revenue_real_demeaned_detrended','wage_growth_negotiated_wages' };
dependent_var_names={'g_real_demeaned_linearly_detrended','y_real_demeaned_linearly_detrended','Effective_FX_Real_Intra_Euro_CPI_log', 'tax_revenue_real_demeaned_detrended','wage_growth_negotiated_wages_cumsum'};
special_name = ['real_fx_intraeuro_',special_name];
dependent_var_plottitles={'Government consumption','GDP','Real Effective Exchange Rate', 'Tax revenues','Negotiated wages'};
dependent_var_ylabels={'percent','percent','percent', 'percent','percent'};

if lag_number<1
    error('Lag number must be strictly positive')
end


% euro countries only: blank out (from the third slice onwards) every
% observation whose euro_country indicator is 0
[row,col] = find(data_array_for_regression_stacked_by_variable(:,:,pos.euro_country)==0);
split_save_name='eur';
Figure_name_split='Euro Countries';

for ii=1:length(row)
    data_array_for_regression_stacked_by_variable(row(ii),col(ii),3:end)=NaN;
end
clear row col;

%% Shock Variable

fe_shocks_temp = data_array_for_regression_stacked_by_variable(:,:,pos.(impulse_name));

% Split the shock into its negative and positive part. NaN entries are copied
% into both parts so that missing observations stay missing.
negative_or_missing = fe_shocks_temp<0 | isnan(fe_shocks_temp);
positive_or_missing = fe_shocks_temp>0 | isnan(fe_shocks_temp);
fe_shocks_negative = zeros(size(fe_shocks_temp));
fe_shocks_negative(negative_or_missing) = fe_shocks_temp(negative_or_missing);
fe_shocks_positive = zeros(size(fe_shocks_temp));
fe_shocks_positive(positive_or_missing) = fe_shocks_temp(positive_or_missing);
if any(any(fe_shocks_negative>0)) || any(any(fe_shocks_positive<0 ))
    error('Sign is wrong')
end

% The part of the shock that is of interest comes first, because only the
% first coefficient of each regression is stored below.
if sign_dummy==-1
    data_array_for_regression_impulse_only=cat(3,fe_shocks_negative,fe_shocks_positive);
elseif sign_dummy==1
    data_array_for_regression_impulse_only=cat(3,fe_shocks_positive,fe_shocks_negative);
elseif sign_dummy==0
    data_array_for_regression_impulse_only=cat(3,fe_shocks_temp);
else
    error('undefined case for sign_dummy')
end

data_array_for_regression=data_array_for_regression_impulse_only;

%% Lagged regressors

for var_iter=1:length(regressor_var_names)
    first_lag_pos=pos.([regressor_var_names{var_iter},'_lag_1']);
    last_lag_pos=pos.([regressor_var_names{var_iter},'_lag_',num2str(lag_number)]);
    data_array_control=data_array_for_regression_stacked_by_variable(:,:,first_lag_pos:last_lag_pos);
    data_array_for_regression=cat(3,data_array_for_regression,data_array_control);
end

%% Run regression

theta_mat=NaN(1+max_irf_horizon,1,length(dependent_var_names));
se_mat=NaN(1+max_irf_horizon,1,length(dependent_var_names));

resids=[];
for var_iter=1:length(dependent_var_names)
    dependent_name=dependent_var_names{var_iter};
    % Variables whose name ends in '_cumsum' are stored as <name>_0, <name>_1, ...
    % for every horizon; all other variables are stored as <name> at horizon 0
    % and <name>_lead_h at horizon h>0.
    is_cumsum_variable=~isempty(regexp(dependent_name,'_cumsum$','once'));
    for horizon_iter = 0:max_irf_horizon
        if is_cumsum_variable
            dependent_field=[dependent_name,'_',num2str(horizon_iter)];
        elseif horizon_iter==0
            dependent_field=dependent_name;
        else
            dependent_field=[dependent_name,'_lead_',num2str(horizon_iter)];
        end
        dependent_variable=squeeze(data_array_for_regression_stacked_by_variable(:,:,pos.(dependent_field)));

        % From horizon 1 onwards the residuals of the previous horizon are added
        % as an extra regressor (resids is empty when they are switched off).
        if horizon_iter==0
            regressor_array=data_array_for_regression;
        else
            regressor_array=cat(3,data_array_for_regression,resids);
        end

        %construct matrices without NaN
        [dependent_variable_for_run,data_array_for_regression_for_run,time_indices_non_NaN,country_indices_non_NaN]=create_regression_matrices_no_NaN(dependent_variable,regressor_array,data_array_for_regression_stacked_by_variable,pos,country_indicator_names_mapping,time_fixed_effects_dummy,country_fixed_effects_dummy);
        %run regression
        % NOTE(review): HszDk5cPs is defined outside this file; the meaning of
        % the trailing arguments (1,7,1) cannot be verified from here.
        [theta,stdDK,~,~,yhat] = HszDk5cPs(dependent_variable_for_run,ones(size(dependent_variable_for_run,1),1),data_array_for_regression_for_run,1,7,1);

        % At horizon 0 no residuals are kept if the dependent variable is the
        % impulse itself (when regression G on itself, residuals are 0).
        keep_residuals = add_residual_dummy && (horizon_iter>0 || ~strcmp(impulse_name,dependent_name));
        if keep_residuals
            yhat_full=NaN(size(dependent_variable));
            yhat_full(time_indices_non_NaN,country_indices_non_NaN)=yhat;
            resids=dependent_variable-yhat_full;
        else
            resids=[];
        end

        %save coefficients (first coefficient and its standard error only)
        theta_mat(1+horizon_iter,:,var_iter) = theta(1,1);
        se_mat(1+horizon_iter,:,var_iter)  = stdDK(1,1);
    end
end

%% Plot IRFs

% Subplot slots in the 5x3 grid: first column for sign_dummy -1 (or 0),
% second column for sign_dummy 1.
if length(dependent_var_names) == 3
    if sign_dummy == -1
        subplot_index = [1,4,7];
    elseif sign_dummy == 1
        subplot_index = [2,5,8];
    else
        error('not defined')
    end
    startsp = 1;
elseif length(dependent_var_names) == 4
    if sign_dummy == -1 || sign_dummy == 0
        subplot_index = [1,4,7,10];
    elseif sign_dummy == 1
        subplot_index = [2,5,8,11];
    else
        error('not defined')
    end
    startsp = 1;
elseif length(dependent_var_names) == 5
    if sign_dummy == -1 || sign_dummy == 0
        subplot_index = [1,4,7,10,13];
    elseif sign_dummy == 1
        subplot_index = [2,5,8,11,14];
    else
        error('not defined')
    end
    startsp = 1;
else
    error('Specify Subplot Layout')
end


if sign_dummy == -1
%     scaling = -0.25; % cut in government spending
    scaling = -100; % cut in government spending
    sign_plottitle = 'Gov. consumption cut';
elseif sign_dummy == 1
%     scaling = 0.25; % increase in government spending
    scaling = 100; % increase in government spending
    sign_plottitle = 'Gov. consumption hike';
elseif sign_dummy == 0
%     scaling = 0.25; % increase in government spending
    scaling = 100; % increase in government spending
    sign_plottitle = 'Gov. consumption cut (symmetric model)';
end

nptsvar=max_irf_horizon;
% two-sided standard-normal critical values for the 90 and 68 percent bands
confidence_90  = norminv(0.9 + (1 - 0.9) / 2, 0, 1);
confidence_68  = norminv(0.68 + (1 - 0.68) / 2, 0, 1);

if sign_dummy == -1 || sign_dummy == 0
    main_fig = figure;
elseif sign_dummy == 1
    % continue in the figure written by the sign_dummy==-1 run
    main_fig = openfig([figures_name,filesep,split_save_name,'_asym_neg_',special_name]);
else
    error('not defined yet')
end

plot_iter = 1;
for sp = startsp:length(dependent_var_names)
    if sp == 3
        series_scaling = -scaling; %account for definition of FX in Eurostat data, which is the opposite of the one in the paper
    else
        series_scaling = scaling;
    end
    figure(main_fig)

    subplot(5,3,subplot_index(plot_iter))

    ci1b_90 = series_scaling*(theta_mat(:,1,sp)+confidence_90*se_mat(:,1,sp));
    ci2b_90 = series_scaling*(theta_mat(:,1,sp)-confidence_90*se_mat(:,1,sp));
    ci1b_68 = series_scaling*(theta_mat(:,1,sp)+confidence_68*se_mat(:,1,sp));
    ci2b_68 = series_scaling*(theta_mat(:,1,sp)-confidence_68*se_mat(:,1,sp));

    % a negative scaling swaps upper and lower bound, hence max/min
    topb_90 = max(ci1b_90,ci2b_90);
    bottomb_90 = min(ci1b_90,ci2b_90);
    topb_68 = max(ci1b_68,ci2b_68);
    bottomb_68 = min(ci1b_68,ci2b_68);

    % stacked areas: the lower layer is made invisible so that only the band
    % between bottom and top is shaded
    ha1_90 = area(0:nptsvar,[bottomb_90, topb_90-bottomb_90],'FaceColor',[204/255 229/255 1],'EdgeColor','none','ShowBaseLine','off');
    set(ha1_90(1), 'FaceColor', 'none') % this makes the bottom area invisible
    set(ha1_90, 'LineStyle', '-')
    hold on
    ha1_68 = area(0:nptsvar,[bottomb_68, topb_68-bottomb_68],'FaceColor',[153/255 204/255 1],'EdgeColor','none','ShowBaseLine','off');
    set(ha1_68(1), 'FaceColor', 'none') % this makes the bottom area invisible
    set(ha1_68, 'LineStyle', '-')
    hold on

    plot(0:nptsvar,series_scaling*theta_mat(:,1,sp),'b-', 'LineWidth', 2.5);

    hline(0,'k:')
    xlim([0 max_irf_horizon])
    box on;set(gca,'xTick',0:max_irf_horizon,'Layer','top','FontSize',fontsize);

    title(dependent_var_plottitles(sp),'FontSize',fontsize)
    ylabel(dependent_var_ylabels(sp),'FontSize',fontsize)
    xlabel('quarters','FontSize',fontsize)
    plot_iter = plot_iter + 1;
end

set(findall(main_fig,'-property','ShowBaseLine'),'ShowBaseLine','off')


% save main figure
if sign_dummy==-1
    saveas(main_fig,[figures_name,filesep,split_save_name,'_asym_neg_',special_name]);
    print([figures_name,filesep,split_save_name,'_asym_neg_',special_name],'-depsc2')
    close(main_fig)
elseif sign_dummy==1
    set(main_fig,'Units','Inches');
    % separate name so that the variable-index structure pos is not overwritten
    figure_position = get(main_fig,'Position');
    set(main_fig,'PaperPositionMode','Auto','PaperUnits','Inches','PaperSize',[figure_position(3), figure_position(4)])
    saveas(main_fig,[figures_name,filesep,split_save_name,'_asym_pos_',special_name]);
    print([figures_name,filesep,split_save_name,'_asym_pos_',special_name],'-dpdf')
elseif sign_dummy==0
    % NOTE(review): this branch cannot be reached through sign_iter (only -1 and
    % 1 are assigned above), and indicator_save_name is not defined in this
    % function - it would have to come from one of the loaded files.
    saveas(main_fig,[figures_name,filesep,split_save_name,'_',indicator_save_name,'_sym_',special_name]);
    %     print(['Figures',filesep,split_save_name,'_',indicator_save_name,'_sym'],'-dpdf')
    print([figures_name,filesep,split_save_name,'_',indicator_save_name,'_sym_',special_name],'-depsc2')
end
